{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6a64c02e",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "---\n",
    "description: Converting PyTorch Datasets to Avalanche Dataset\n",
    "---\n",
    "\n",
    "# Preamble: from PyTorch to Avalanche Datasets\n",
    "This short preamble will briefly go through the basic notions of Dataset offered natively by PyTorch. A solid grasp of these notions are needed to understand:\n",
    "1. How PyTorch data loading works in general\n",
    "2. How AvalancheDatasets differs from PyTorch Datasets\n",
    "3. How to instantiate Avalanche Datasets\n",
    "4. AvalancheDataset features\n",
    "\n",
    "## 📚 PyTorch Dataset: general definition\n",
    "\n",
    "In PyTorch, **a `Dataset` is a class** exposing two methods:\n",
    "- `__len__()`, which returns the amount of instances in the dataset (as an `int`). \n",
    "- `__getitem__(idx)`, which returns the data point at index `idx`.\n",
    "\n",
    "In other words, a Dataset instance is just an object for which, similarly to a list, one can simply:\n",
    "- Obtain its length using the Python `len(dataset)` function.\n",
    "- Obtain a single data point using the `x, y = dataset[idx]` syntax.\n",
    "\n",
    "The content of the dataset can be either loaded in memory when the dataset is instantiated (like the torchvision MNIST dataset does) or, for big datasets like ImageNet, the content is kept on disk, with the dataset keeping the list of files in an internal field. In this case, data is loaded from the storage on-the-fly when `__getitem__(idx)` is called. The way those things are managed is specific to each dataset implementation.\n",
    "\n",
    "### Quick note on the IterableDataset class\n",
    "A variation of the standard `Dataset` exist in PyTorch: the [IterableDataset](https://pytorch.org/docs/stable/data.html#iterable-style-datasets). When using an `IterableDataset`, one can load the data points in a sequential way only (by using a tape-alike approach). The `dataset[idx]` syntax and `len(dataset)` function are not allowed. **Avalanche does NOT support `IterableDataset`s.** You shouldn't worry about this because, realistically, you will never encounter such datasets (at least in torchvision). If you need `IterableDataset` let us know and we will consider adding support for them.\n",
    "\n",
    "\n",
    "## Wrapping PyTorch Datasets\n",
    "To create an `AvalancheDataset` starting from a PyTorch you only need to pass the original data to the constructor as follows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f9db09c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data.dataset import TensorDataset\n",
    "from avalanche.benchmarks.utils import AvalancheDataset\n",
    "\n",
    "# Create a dataset of 100 data points described by 22 features + 1 class label\n",
    "x_data = torch.rand(100, 22)\n",
    "y_data = torch.randint(0, 5, (100,))\n",
    "\n",
    "# Create the Dataset\n",
    "torch_data = TensorDataset(x_data, y_data)\n",
    "\n",
    "avl_data = AvalancheDataset(torch_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2810b82",
   "metadata": {},
   "source": [
    "The dataset is equivalent to the original one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "20464893",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([0.9110, 0.9209, 0.3404, 0.4981, 0.6954, 0.4509, 0.9966, 0.5070, 0.9475,\n",
      "        0.4813, 0.6436, 0.3680, 0.1872, 0.9233, 0.7774, 0.6226, 0.6519, 0.1818,\n",
      "        0.3813, 0.0060, 0.0607, 0.5219]), tensor(0))\n",
      "(tensor([0.9110, 0.9209, 0.3404, 0.4981, 0.6954, 0.4509, 0.9966, 0.5070, 0.9475,\n",
      "        0.4813, 0.6436, 0.3680, 0.1872, 0.9233, 0.7774, 0.6226, 0.6519, 0.1818,\n",
      "        0.3813, 0.0060, 0.0607, 0.5219]), tensor(0))\n"
     ]
    }
   ],
   "source": [
    "print(torch_data[0])\n",
    "print(avl_data[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cc30ca6",
   "metadata": {},
   "source": [
    "most of the time, you can also use one of the utility function in [benchmark utils](https://avalanche-api.continualai.org/en/latest/benchmarks.html#utils-data-loading-and-avalanchedataset) that also add attributes such as class and task labels to the dataset. For example, you can create a classification dataset using `make_classification_dataset`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eebb7d03",
   "metadata": {},
   "outputs": [],
   "source": [
    "from avalanche.benchmarks.utils import make_classification_dataset\n",
    "\n",
    "# first, we add targets to the dataset. This will be used by the AvalancheDataset\n",
    "# If possible, avalanche tries to extract the targets from the dataset.\n",
    "# most datasets in torchvision already have a targets field so you don't need this step.\n",
    "torch_data.targets = torch.randint(0, 5, (100,))\n",
    "tls = [0 for _ in range(100)] # one task label for each sample\n",
    "sup_data = make_classification_dataset(torch_data, task_labels=tls)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15bb7cba",
   "metadata": {},
   "source": [
    "## DataLoader\n",
    "Avalanche provides some [custom dataloaders](https://avalanche-api.continualai.org/en/latest/benchmarks.html#utils-data-loading-and-avalanchedataset) to sample in a task-balanced way or to balance the replay buffer and current data, but you can also use the standard pytorch `DataLoader`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6f6b78f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n",
      "Loaded minibatch of 10 instances\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "my_dataloader = DataLoader(avl_data, batch_size=10, shuffle=True)\n",
    "\n",
    "# Run one epoch\n",
    "for x_minibatch, y_minibatch in my_dataloader:\n",
    "    print('Loaded minibatch of', len(x_minibatch), 'instances')\n",
    "# Output: \"Loaded minibatch of 10 instances\" x10 times"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7e0a66f",
   "metadata": {},
   "source": [
    "## Dataset Operations: Concatenation and SubSampling\n",
    "While PyTorch provides two different classes for concatenation and subsampling (`ConcatDataset` and `Subset`), Avalanche implements them as dataset methods. These operation return a new dataset, leaving the original one unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c6c3e233",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200\n",
      "100\n",
      "50\n",
      "100\n"
     ]
    }
   ],
   "source": [
    "cat_data = avl_data.concat(avl_data)\n",
    "print(len(cat_data))  # 100 + 100 = 200\n",
    "print(len(avl_data))  # 100, original data stays the same\n",
    "\n",
    "sub_data = avl_data.subset(list(range(50)))\n",
    "print(len(sub_data))  # 50\n",
    "print(len(avl_data))  # 100, original data stays the same"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aaceed33",
   "metadata": {},
   "source": [
    "## Dataset Attributes\n",
    "AvalancheDataset allows to attach attributes to datasets. For example, classification datasets use this functionality to manage class and task labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "eba0f9bf",
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "sup_data = make_classification_dataset(torch_data, task_labels=tls)\n",
    "\n",
    "sup_data.targets, sup_data.targets_task_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e15b060b",
   "metadata": {},
   "source": [
    "## Transformations\n",
    "Most datasets from the *torchvision* libraries (as well as datasets found \"in the wild\") allow for a `transformation` function to be passed to the dataset constructor. The support for transformations is not mandatory for a dataset, but it is quite common to support them. The transformation is used to process the X value of a data point before returning it. This is used to normalize values, apply augmentations, etcetera.\n",
    "\n",
    "As explained in the mini *How-To*s, the `AvalancheDataset` class implements a very rich and powerful set of functionalities for managing transformations.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad256162",
   "metadata": {},
   "source": [
    "## Next steps\n",
    "With these notions in mind, you can start start your journey on understanding the functionalities offered by the AvalancheDatasets by going through the *Mini How-To*s.\n",
    "\n",
    "Please refer to the [list of the *Mini How-To*s regarding AvalancheDatasets](https://avalanche.continualai.org/how-tos/avalanchedataset) for a complete list. It is recommended to start with the **\"Creating AvalancheDatasets\"** *Mini How-To*."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b28db60e",
   "metadata": {},
   "source": [
    "## 🤝 Run it on Google Colab\n",
    "\n",
    "You can run _this chapter_ and play with it on Google Colaboratory by clicking here: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ContinualAI/avalanche/blob/master/notebooks/how-tos/avalanchedataset/preamble-pytorch-datasets.ipynb)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
